// Copyright Contributors to the Open Shading Language project.
// SPDX-License-Identifier: BSD-3-Clause
// https://github.com/imageworks/OpenShadingLanguage

#include <algorithm>
#include <vector>
#include <string>
#include <sstream>
#include <functional>
#ifndef NDEBUG
#include <atomic>
#endif

#include "osl_pvt.h"
#include "oslcomp_pvt.h"

#include <OpenImageIO/filesystem.h>
#include <OpenImageIO/strutil.h>
#include <OpenImageIO/atomic.h>
namespace Strutil = OIIO::Strutil;

OSL_NAMESPACE_ENTER

namespace pvt {   // OSL::pvt


#ifndef NDEBUG
// Debug-build bookkeeping: per-node-type counters that the ASTNode
// constructors and destructor adjust, plus a second per-type counter that
// is reported as the "peak". At program exit, every type whose live count
// is still positive is printed, which makes leaked AST nodes easy to spot.
namespace {
std::atomic<int> node_counts[ASTNode::_last_node];
std::atomic<int> node_counts_peak[ASTNode::_last_node];

// Tiny scope guard: stores a callable and invokes it on destruction.
// Because the instance below lives at namespace scope, its task runs during
// static destruction at program exit.
class ScopeExit {
public:
    typedef std::function<void()> Task;
    explicit ScopeExit (Task&& task) : m_task(std::forward<Task>(task)) {}
    ~ScopeExit () { m_task(); }
private:
    Task m_task;
};

ScopeExit print_node_counts ([](){
    for (int type = 0; type < ASTNode::_last_node; ++type) {
        const int alive = node_counts[type].load();
        if (alive <= 0)
            continue;   // nothing outstanding for this node type
        Strutil::printf ("ASTNode type %2d: %5d   (peak %5d)\n",
                         type, alive, node_counts_peak[type].load());
    }
});
}
#endif



// Reverse, in place, a singly linked list of AST nodes chained through
// m_next, returning the new head. An empty list yields an empty ref.
ASTNode::ref
reverse (ASTNode::ref list)
{
    ASTNode::ref head;
    for (ASTNode::ref rest; list; list = rest) {
        rest = list->next();     // remember the remainder before relinking
        list->m_next = head;     // point this node back at the reversed part
        head = list;             // it becomes the new front
    }
    return head;
}



ASTNode::ASTNode (NodeType nodetype, OSLCompilerImpl *compiler) 
    : m_nodetype(nodetype), m_compiler(compiler),
      m_sourcefile(compiler->filename()),
      m_sourceline(compiler->lineno()), m_op(0), m_is_lvalue(false)
{
#ifndef NDEBUG
    node_counts[nodetype] += 1;
    node_counts_peak[nodetype] += 1;
#endif
}



ASTNode::ASTNode (NodeType nodetype, OSLCompilerImpl *compiler, int op,
                  ASTNode *a)
    : m_nodetype(nodetype), m_compiler(compiler),
      m_sourcefile(compiler->filename()),
      m_sourceline(compiler->lineno()), m_op(op), m_is_lvalue(false)
{
    addchild (a);
#ifndef NDEBUG
    node_counts[nodetype] += 1;
    node_counts_peak[nodetype] += 1;
#endif
}



ASTNode::ASTNode (NodeType nodetype, OSLCompilerImpl *compiler, int op)
    : m_nodetype(nodetype), m_compiler(compiler),
      m_sourcefile(compiler->filename()),
      m_sourceline(compiler->lineno()), m_op(op), m_is_lvalue(false)
{
#ifndef NDEBUG
    node_counts[nodetype] += 1;
    node_counts_peak[nodetype] += 1;
#endif
}



ASTNode::ASTNode (NodeType nodetype, OSLCompilerImpl *compiler, int op,
                  ASTNode *a, ASTNode *b)
    : m_nodetype(nodetype), m_compiler(compiler),
      m_sourcefile(compiler->filename()),
      m_sourceline(compiler->lineno()), m_op(op), m_is_lvalue(false)
{
    addchild (a);
    addchild (b);
#ifndef NDEBUG
    node_counts[nodetype] += 1;
    node_counts_peak[nodetype] += 1;
#endif
}



ASTNode::ASTNode (NodeType nodetype, OSLCompilerImpl *compiler, int op,
                  ASTNode *a, ASTNode *b, ASTNode *c)
    : m_nodetype(nodetype), m_compiler(compiler),
      m_sourcefile(compiler->filename()),
      m_sourceline(compiler->lineno()), m_op(op), m_is_lvalue(false)
{
    addchild (a);
    addchild (b);
    addchild (c);
#ifndef NDEBUG
    node_counts[nodetype] += 1;
    node_counts_peak[nodetype] += 1;
#endif
}



ASTNode::ASTNode (NodeType nodetype, OSLCompilerImpl *compiler, int op,
                  ASTNode *a, ASTNode *b, ASTNode *c, ASTNode *d)
    : m_nodetype(nodetype), m_compiler(compiler),
      m_sourcefile(compiler->filename()),
      m_sourceline(compiler->lineno()), m_op(op), m_is_lvalue(false)
{
    addchild (a);
    addchild (b);
    addchild (c);
    addchild (d);
#ifndef NDEBUG
    node_counts[nodetype] += 1;
    node_counts_peak[nodetype] += 1;
#endif
}



/// Destroy this node. The sibling chain hanging off m_next is released one
/// node at a time in the loop below, instead of letting each node's
/// destructor release its successor recursively.
ASTNode::~ASTNode ()
{
    // For sufficiently deep trees, the recursive deletion of nodes could
    // overflow the stack. So do it sequentially.
    while (m_next) {
        // Currently:
        //     m_next -->  A  --> B --> ...
        ref n = m_next;
        // Now:
        //     n --> A --> B --> ...
        //     m_next -->  A  --> B --> ...
        m_next = n->m_next;
        // Now:
        //     n --> A --> B --> ...
        //     m_next -->  B --> ...
        n->m_next.reset();
        // Now:
        //     n --> A
        //     m_next -->  B --> ...
        // When we loop, n will exit scope and we will have
        //     m_next --> B --> ...
        // A will have been freed, next time through the loop we will free
        // B, and there was no recursion.
    }

#ifndef NDEBUG
    // Debug builds: one fewer live node of this type (see node_counts).
    node_counts[nodetype()] -= 1;
#endif
}



// Diagnostic forwarding. Each helper below hands an already-formatted
// message to the owning compiler, tagged with this node's source file and
// line. The message is passed as the argument to a "%s" format, so any '%'
// characters it contains are printed literally rather than being treated
// as format directives.
void
ASTNode::error_impl (string_view msg) const
{
    OSLCompilerImpl *comp = m_compiler;
    comp->errorf (sourcefile(), sourceline(), "%s", msg);
}



void
ASTNode::warning_impl (string_view msg) const
{
    OSLCompilerImpl *comp = m_compiler;
    comp->warningf (sourcefile(), sourceline(), "%s", msg);
}



void
ASTNode::info_impl (string_view msg) const
{
    OSLCompilerImpl *comp = m_compiler;
    comp->infof (sourcefile(), sourceline(), "%s", msg);
}



void
ASTNode::message_impl (string_view msg) const
{
    OSLCompilerImpl *comp = m_compiler;
    comp->messagef (sourcefile(), sourceline(), "%s", msg);
}



// Generic debug dump of a node. It prints the opening line (node type name,
// type, and opname if any), then the children, then a closing paren at the
// same indentation.
void
ASTNode::print (std::ostream &out, int indentlevel) const
{
    indent (out, indentlevel);
    const char *op = opname();
    out << "(" << nodetypename() << " : "
        << "    (type: " << typespec().string() << ") "
        << (op ? op : "") << "\n";
    printchildren (out, indentlevel);
    indent (out, indentlevel);
    out << ")\n";
}



// Print every non-null child list, each labeled with childname(i), or with
// "child<i>" when childname returns null. The lists themselves are printed
// one indentation level deeper.
void
ASTNode::printchildren (std::ostream &out, int indentlevel) const
{
    const bool have_type = (typespec() != TypeSpec());
    for (size_t c = 0, e = m_children.size();  c < e;  ++c) {
        auto kid = child(c);
        if (! kid)
            continue;
        indent (out, indentlevel);
        const char *label = childname(c);
        if (label)
            out << "  " << label;
        else
            out << "  child" << c;
        out << ": ";
        // NOTE(review): the condition inspects the child, but the type that
        // is printed is this (parent) node's own typespec -- confirm which
        // was intended.
        if (have_type && ! kid->next())
            out << " (type: " << typespec().string() << ")";
        out << "\n";
        printlist (out, kid, indentlevel+1);
    }
}



// Render a TypeSpec as a C string; this simply defers to the owning compiler.
const char *
ASTNode::type_c_str (const TypeSpec &type) const
{
    const char *text = m_compiler->type_c_str (type);
    return text;
}



void
ASTNode::list_to_vec (const ref &A, std::vector<ref> &vec)
{
    vec.clear ();
    for (ref node = A; node; node = node->next())
        vec.push_back (node);
}



// Relink the nodes in vec into a singly linked list in vector order, clear
// the last node's m_next, and return the first node. An empty vector gives
// an empty ref.
ASTNode::ref
ASTNode::vec_to_list (std::vector<ref> &vec)
{
    if (vec.empty())
        return ref();
    for (size_t i = 1;  i < vec.size();  ++i)
        vec[i-1]->m_next = vec[i];
    vec.back()->m_next = NULL;
    return vec.front();
}



// Produce a ", "-separated rendering of the typespec of every node in the
// list starting at `node` (following nextptr()). Empty list -> "".
std::string
ASTNode::list_to_types_string (const ASTNode *node)
{
    std::ostringstream result;
    const char *sep = "";
    for ( ;  node;  node = node->nextptr()) {
        result << sep << node->typespec();
        sep = ", ";
    }
    return result.str();
}



// Shader declaration node. Children are (metadata, formals, statements);
// the shader type is stored as the node's op code.
ASTshader_declaration::ASTshader_declaration (OSLCompilerImpl *comp,
                                int stype, ustring name, ASTNode *form,
                                ASTNode *stmts, ASTNode *meta)
    : ASTNode (shader_declaration_node, comp, stype, meta, form, stmts),
      m_shadername(name)
{
    // Validate the shader parameters: each needs a default initializer,
    // and output parameters may not be unsized arrays.
    for (ASTNode *node = form;  node;  node = node->nextptr()) {
        OSL_ASSERT (node->nodetype() == variable_declaration_node);
        ASTvariable_declaration *param = static_cast<ASTvariable_declaration *>(node);
        if (! param->init())
            param->errorf("shader parameter '%s' requires a default initializer",
                          param->name());
        const bool unsized_output = param->is_output()
                                    && param->typespec().is_unsized_array();
        if (unsized_output)
            param->errorf("shader output parameter '%s' can't be unsized array",
                          param->name());
    }
}



const char *
ASTshader_declaration::childname (size_t i) const
{
    static const char *const names[] = { "metadata", "formals", "statements" };
    return names[i];
}



void
ASTshader_declaration::print (std::ostream &out, int indentlevel) const
{
    indent (out, indentlevel);
    out << "(" << nodetypename() << " " << shadertypename() << " \""
        << m_shadername << "\"\n";
    printchildren (out, indentlevel);
    indent (out, indentlevel);
    out << ")\n";
}



// Human-readable name of this shader's type (stored in m_op).
string_view
ASTshader_declaration::shadertypename () const
{
    const ShaderType st = static_cast<ShaderType>(m_op);
    return OSL::pvt::shadertypename (st);
}



/// Function declaration node. Children are (metadata, formals, statements).
/// The constructor validates the name, builds the argcode signature from the
/// return type and formals, warns about same-scope redefinitions with
/// identical argcodes, creates and inserts a FunctionSymbol chained (via
/// nextpoly) to any existing same-named function symbols, and typechecks the
/// declaration immediately.
///
/// If any formal has an unknown type, the constructor returns early with
/// m_typespec UNKNOWN and m_sym still NULL: no symbol is created or inserted
/// in that case, so later uses of m_sym / func() must tolerate NULL.
ASTfunction_declaration::ASTfunction_declaration (OSLCompilerImpl *comp,
                             TypeSpec type, ustring name,
                             ASTNode *form, ASTNode *stmts, ASTNode *meta,
                             int sourceline_start)
    : ASTNode (function_declaration_node, comp, 0, meta, form, stmts),
      m_name(name), m_sym(NULL), m_is_builtin(false)
{
    // Some trickery -- the compiler's idea of the "current" source line
    // is the END of the function body, so if a hint was passed about the
    // start of the declaration, substitute that.
    if (sourceline_start >= 0)
        m_sourceline = sourceline_start;

    if (Strutil::starts_with (name, "___"))
        errorf("\"%s\" : sorry, can't start with three underscores", name);

    // Get a pointer to the first of the existing symbols of that name.
    // A clash with a non-function symbol is an error; the clash is then
    // forgotten so the new function is not chained onto it.
    Symbol *existing_syms = comp->symtab().clash (name);
    if (existing_syms && existing_syms->symtype() != SymTypeFunction) {
        errorf("\"%s\" already declared in this scope as a %s",
               name, existing_syms->typespec());
        // FIXME -- print the file and line of the other definition
        existing_syms = NULL;
    }

    // Build up the argument signature for this declared function:
    // the return type's code followed by one code per formal.
    // NOTE(review): this uses the global `oslcompiler` rather than `comp`;
    // presumably they are the same instance -- confirm.
    m_typespec = type;
    std::string argcodes = oslcompiler->code_from_type (m_typespec);
    for (ASTNode *arg = form;  arg;  arg = arg->nextptr()) {
        const TypeSpec &t (arg->typespec());
        if (t == TypeSpec() /* UNKNOWN */) {
            // Bail out: no symbol is created (m_sym stays NULL).
            m_typespec = TypeDesc::UNKNOWN;
            return;
        }
        argcodes += oslcompiler->code_from_type (t);
        OSL_ASSERT (arg->nodetype() == variable_declaration_node);
        ASTvariable_declaration *v = (ASTvariable_declaration *)arg;
        if (v->init())
            v->errorf("function parameter '%s' may not have a default initializer.",
                      v->name());
    }

    // Allow multiple function declarations, but only if they aren't the
    // same polymorphic type in the same scope. This check only runs when
    // this declaration has a body, and a conflict is reported as a warning
    // (listing every previous definition), not an error.
    if (stmts) {
        std::string err;
        int current_scope = oslcompiler->symtab().scopeid();
        for (FunctionSymbol *f = static_cast<FunctionSymbol *>(existing_syms);
             f; f = f->nextpoly()) {
            if (f->scope() == current_scope && f->argcodes() == argcodes) {
                // If the argcodes match, only one should have statements.
                // If there is no ASTNode for the poly, must be a builtin, and
                // has 'implicit' statements.
                auto other = static_cast<ASTfunction_declaration*>(f->node());
                if (!other || (other->statements() || other->is_builtin())) {
                    if (err.empty()) {
                        err = Strutil::sprintf("Function '%s %s (%s)' redefined "
                                              "in the same scope\n"
                                              "  Previous definitions:", type,
                                              name, list_to_types_string(form));
                    }
                    err += "\n    ";
                    if (other) {
                        err += Strutil::sprintf("%s:%d",
                                    OIIO::Filesystem::filename(other->sourcefile().string()),
                                    other->sourceline());
                    } else
                        err += "built-in";
                }
            }
        }
        if (!err.empty())
            warningf("%s", err);
    }


    // Create the symbol, chain it in front of any same-named function
    // symbols (polymorphic overloads), and publish it in the symbol table.
    m_sym = new FunctionSymbol (name, type, this);
    func()->nextpoly ((FunctionSymbol *)existing_syms);

    func()->argcodes (ustring (argcodes));
    oslcompiler->symtab().insert (m_sym);

    // Typecheck it right now, upon declaration
    typecheck (typespec ());
}



/// Apply function-level metadata (a linked list of variable declarations)
/// to this declaration's FunctionSymbol. Recognized metadata names:
///   builtin      -- mark the declaration as built-in; a closure return type
///                   also gets a trailing "." argcode (forcing keyword
///                   arguments at the end), and any output formal marks the
///                   function as a readwrite special case.
///   derivs       -- set takes_derivs.
///   printf_args  -- set printf_args.
///   texture_args -- set texture_args.
///   rw           -- set readwrite_special_case.
/// Other names are ignored.
void
ASTfunction_declaration::add_meta (ref metaref)
{
    // The constructor returns before creating a symbol when a formal has an
    // unknown type, so func() may be null here. In that case only the
    // node-level builtin flag is recorded; there is no symbol to annotate.
    auto *fsym = func();
    for (ASTNode *meta = metaref.get();  meta;  meta = meta->nextptr()) {
        OSL_ASSERT (meta->nodetype() == ASTNode::variable_declaration_node);
        const ASTvariable_declaration *metavar = static_cast<const ASTvariable_declaration *>(meta);
        Symbol *metasym = metavar->sym();
        const bool is_builtin_meta = (metasym->name() == "builtin");
        if (is_builtin_meta)
            m_is_builtin = true;
        if (! fsym)
            continue;   // nothing further can be applied without a symbol

        if (is_builtin_meta) {
            if (fsym->typespec().is_closure())  { // It is a builtin closure
                // Force keyword arguments at the end
                fsym->argcodes (ustring (fsym->argcodes().string() + "."));
            }
            // For built-in functions, if any of the params are output,
            // also automatically mark it as readwrite_special_case.
            for (ASTNode *f = formals().get(); f; f = f->nextptr()) {
                OSL_ASSERT (f->nodetype() == variable_declaration_node);
                ASTvariable_declaration *v = (ASTvariable_declaration *)f;
                if (v->is_output())
                    fsym->readwrite_special_case (true);
            }
        }
        else if (metasym->name() == "derivs")
            fsym->takes_derivs (true);
        else if (metasym->name() == "printf_args")
            fsym->printf_args (true);
        else if (metasym->name() == "texture_args")
            fsym->texture_args (true);
        else if (metasym->name() == "rw")
            fsym->readwrite_special_case (true);
    }
}



// Labels for the (metadata, formals, statements) children.
const char *
ASTfunction_declaration::childname (size_t i) const
{
    static const char *const names[] = { "metadata", "formals", "statements" };
    return names[i];
}



void
ASTfunction_declaration::print (std::ostream &out, int indentlevel) const
{
    indent (out, indentlevel);
    out << nodetypename() << " " << m_sym->mangled();
    if (m_sym->scope())
        out << " (" << m_sym->name()
                  << " in scope " << m_sym->scope() << ")";
    out << "\n";
    printchildren (out, indentlevel);
}



/// Variable / parameter / metadata declaration node. Children are
/// (initializer, metadata); the metadata child starts out NULL.
/// The constructor typechecks a compound init list early and reports a
/// clash with an existing symbol in scope. A shader parameter that masks a
/// global function only gets a warning; other clashes are errors, and
/// metadata declarations are exempt. It then creates the Symbol (SymType
/// chosen from isparam/isoutput) and inserts it into the symbol table unless
/// this is metadata. Struct-typed declarations additionally get one
/// sub-declaration per field via add_struct_fields.
/// NOTE(review): m_sym is allocated with `new` even for metadata, which is
/// never inserted into the symtab; ownership is not visible here -- confirm.
ASTvariable_declaration::ASTvariable_declaration (OSLCompilerImpl *comp,
                                                  const TypeSpec &type,
                                                  ustring name, ASTNode *init,
                                                  bool isparam, bool ismeta,
                                                  bool isoutput, bool initlist,
                                                  int sourceline_start)
    : ASTNode (variable_declaration_node, comp, 0, init, NULL /* meta */),
      m_name(name), m_sym(NULL),
      m_isparam(isparam), m_isoutput(isoutput), m_ismetadata(ismeta),
      m_initlist(initlist)
{
    // Some trickery -- the compiler's idea of the "current" source line
    // is the END of the declaration, so if a hint was passed about the
    // start of the declaration, substitute that.
    if (sourceline_start >= 0)
        m_sourceline = sourceline_start;

    if (m_initlist && init) {
        // Typecheck the init list early.
        OSL_ASSERT (init->nodetype() == compound_initializer_node);
        static_cast<ASTcompound_initializer*>(init)->typecheck(type);
    }

    m_typespec = type;
    // Reject (or, in one special case, warn about) a name that clashes with
    // an existing symbol; metadata declarations skip this check.
    Symbol *f = comp->symtab().clash (name);
    if (f  &&  ! m_ismetadata) {
        std::string e = Strutil::sprintf ("\"%s\" already declared in this scope", name.c_str());
        if (f->node()) {
            std::string filename = OIIO::Filesystem::filename(f->node()->sourcefile().string());
            e += Strutil::sprintf ("\n\t\tprevious declaration was at %s:%d",
                                   filename, f->node()->sourceline());
        }
        if (f->scope() == 0 && f->symtype() == SymTypeFunction && isparam) {
            // special case: only a warning for param to mask global function
            warningf("%s", e);
        } else {
            errorf("%s", e);
        }
    }
    if (OIIO::Strutil::starts_with (name, "___")) {
        errorf("\"%s\" : sorry, can't start with three underscores", name);
    }
    SymType symtype = isparam ? (isoutput ? SymTypeOutputParam : SymTypeParam)
                              : SymTypeLocal;
    // Sneaky debugging aid: a local that starts with "__debug_tmp__"
    // gets declared as a temp. Don't do this on purpose!!!
    if (symtype == SymTypeLocal && Strutil::starts_with (name, "__debug_tmp__"))
        symtype = SymTypeTemp;
    m_sym = new Symbol (name, type, symtype, this);
    // Metadata symbols are kept on the node only, never in the symtab.
    if (! m_ismetadata)
        oslcompiler->symtab().insert (m_sym);

    // A struct really makes several subvariables
    if (type.is_structure() || type.is_structure_array()) {
        OSL_ASSERT (! m_ismetadata);
        // Add the fields as individual declarations (array length -1 is
        // passed for an unsized array).
        m_compiler->add_struct_fields (type.structspec(), m_sym->name(), symtype,
                                       type.is_unsized_array() ? -1 : type.arraylength(),
                                       this, init);
    }
}



// Parameters and ordinary variables share this node type but print
// under different names.
const char *
ASTvariable_declaration::nodetypename () const
{
    if (m_isparam)
        return "parameter";
    return "variable_declaration";
}



// Labels for the (initializer, metadata) children.
const char *
ASTvariable_declaration::childname (size_t i) const
{
    static const char *const names[] = { "initializer", "metadata" };
    return names[i];
}



// Debug dump: "(<kind> <type> <mangled name>", the children, then ")".
void
ASTvariable_declaration::print (std::ostream &out, int indentlevel) const
{
    indent (out, indentlevel);
    out << "(" << nodetypename() << " " << m_sym->typespec().string()
        << " " << m_sym->mangled();
#if 0
    if (m_sym->scope())
        out << " (" << m_sym->name()
                  << " in scope " << m_sym->scope() << ")";
#endif
    out << "\n";
    printchildren (out, indentlevel);
    indent (out, indentlevel);
    out << ")\n";
}



// Reference to a named variable. The name is resolved in the current
// scope. Referring to an undeclared name, or to a function, is an error
// and leaves the node's type unset; otherwise the node takes the type of
// the symbol.
ASTvariable_ref::ASTvariable_ref (OSLCompilerImpl *comp, ustring name)
    : ASTNode (variable_ref_node, comp), m_name(name), m_sym(NULL)
{
    m_sym = comp->symtab().find (name);
    if (m_sym == NULL) {
        errorf("'%s' was not declared in this scope", name);
        // FIXME -- would be fun to troll through the symtab and try to
        // find the things that almost matched and offer suggestions.
    } else if (m_sym->symtype() == SymTypeFunction) {
        errorf("function '%s' can't be used as a variable", name);
    } else {
        m_typespec = m_sym->typespec();
    }
}



// Debug dump on a single line. When the name never resolved, prints
// "unknown" for the type and the raw name.
void
ASTvariable_ref::print (std::ostream &out, int indentlevel) const
{
    indent (out, indentlevel);
    out << "(" << nodetypename() << " (type: ";
    if (m_sym)
        out << m_sym->typespec().string() << ") " << m_sym->mangled();
    else
        out << "unknown) " << m_name.string();
    out << ")\n";
    OSL_DASSERT(nchildren() == 0);
}



// Pre-increment/decrement node with a single operand child. The operand is
// checked for writeability on construction.
ASTpreincdec::ASTpreincdec (OSLCompilerImpl *comp, int op, ASTNode *expr)
    : ASTNode (preincdec_node, comp, op, expr)
{
    ASTNode *operand = expr;
    check_symbol_writeability (operand);
}



const char *
ASTpreincdec::childname (size_t i) const
{
    static const char *const names[] = { "expression" };
    return names[i];
}



// Post-increment/decrement node with a single operand child. The operand is
// checked for writeability on construction.
ASTpostincdec::ASTpostincdec (OSLCompilerImpl *comp, int op, ASTNode *expr)
    : ASTNode (postincdec_node, comp, op, expr)
{
    ASTNode *operand = expr;
    check_symbol_writeability (operand);
}



const char *
ASTpostincdec::childname (size_t i) const
{
    static const char *const names[] = { "expression" };
    return names[i];
}



/// Indexing node: expr[index], optionally with index2 and index3 for
/// two- and three-level dereferences. Children are (expr, index[, index2
/// [, index3]]). The result type is derived from expr's type:
///   1 index : array element type, or float for a non-closure triple;
///   2 indices: float for a matrix or an array of triples;
///   3 indices: float for an array of matrices.
/// Any other combination is reported as an error and leaves the type unknown.
ASTindex::ASTindex (OSLCompilerImpl *comp, ASTNode *expr, ASTNode *index,
          ASTNode *index2, ASTNode *index3)
    : ASTNode (index_node, comp, 0, expr, index /*NO: index2, index3*/)
{
    // The base constructor only added expr and index as children. Add
    // more if additional index arguments were supplied.
    OSL_DASSERT(index);
    if (index2)
        addchild(index2);
    if (index3)
        addchild(index3);

    // Special case: an ASTindex where the `expr` is itself an ASTindex.
    // This construction results from named-component access of array
    // elements, e.g., `colorarray[i].r`. In that case, what we want to do
    // is rearrange to turn this into the two-index variety and discard the
    // child index.
    // The inner node's children are captured in refs before clearchildren()
    // so they stay alive; the inner ASTindex itself may be released here if
    // nothing else references it, which is why `expr` is repointed.
    if (!index2 && expr->nodetype() == index_node && expr->nchildren() == 2) {
        ref newexpr = static_cast<ASTindex*>(expr)->lvalue();
        ref newindex = static_cast<ASTindex*>(expr)->index();
        ref newindex2 = index;
        clearchildren();
        addchild(newexpr);      expr   = newexpr.get();
        addchild(newindex);     index  = newindex.get();
        addchild(newindex2);    index2 = newindex2.get();
    }

    OSL_DASSERT(expr->nodetype() == variable_ref_node ||
             expr->nodetype() == structselect_node);
    OSL_DASSERT(m_typespec.is_unknown());

    if (!index2) {
        // 1-argument: simple array a[i] or component dereference triple[c]
        if (expr->typespec().is_array())       // array dereference
            m_typespec = expr->typespec().elementtype();
        else if (!expr->typespec().is_closure() &&
                 expr->typespec().is_triple()) // component access
            m_typespec = TypeDesc::FLOAT;
    } else if (!index3) {
        // 2-argument: matrix dereference m[r][c], or component of a
        // triple array colorarray[i][c].
        if (expr->typespec().is_matrix())  // matrix component access
            m_typespec = TypeDesc::FLOAT;
        else if (expr->typespec().is_array() &&   // triplearray[][]
                 expr->typespec().elementtype().is_triple())
            m_typespec = TypeDesc::FLOAT;
    } else {
        // 3-argument: one component of an array of matrices
        // matrixarray[i][r][c]
        if (expr->typespec().is_array() &&   // matrixarray[][]
                 expr->typespec().elementtype().is_matrix())
            m_typespec = TypeDesc::FLOAT;
    }

    // None of the recognized shapes matched.
    if (m_typespec.is_unknown()) {
        errorf("indexing into non-array or non-component type");
    }
}



// Labels for the (expression, index, index) children.
const char *
ASTindex::childname (size_t i) const
{
    static const char *const names[] = { "expression", "index", "index" };
    return names[i];
}



// Field selection `expr.field`. A genuine struct field takes the flattened
// field symbol's name and type. A named component (e.g. point.x) has no
// symbol; find_fieldsym builds m_compindex for it, and the result is a float.
ASTstructselect::ASTstructselect (OSLCompilerImpl *comp, ASTNode *expr,
                                  ustring field)
    : ASTNode (structselect_node, comp, 0, expr), m_field(field),
      m_structid(-1), m_fieldid(-1), m_fieldname(field), m_fieldsym(NULL)
{
    m_fieldsym = find_fieldsym (m_structid, m_fieldid);
    if (m_fieldsym != NULL) {
        m_fieldname = m_fieldsym->name();
        m_typespec = m_fieldsym->typespec();
        return;
    }
    if (m_compindex) {
        // It's a named component, like point.x
        m_typespec = OIIO::TypeFloat;  // These cases are always single floats
    }
}



/// Resolve what this selection refers to.
/// - color .r/.g/.b, or .x/.y/.z on a type for which is_vectriple() is
///   true: there is no field symbol. Instead, set fieldid to the component
///   index (0-2), build m_compindex as an ASTindex of that component, mark
///   this node as an lvalue, and return nullptr.
/// - struct (or struct array): set structid to the struct type's id and
///   fieldid to the field's position, and return the flattened
///   "<structsym>.<field>" symbol from the symbol table (NULL if absent).
/// - anything else, or an unknown struct member: report an error and
///   return NULL.
Symbol *
ASTstructselect::find_fieldsym (int &structid, int &fieldid)
{
    auto lv = lvalue().get();
    auto lvtype = lv->typespec();

    // Position (0, 1, 2) of fieldname() among three component letters,
    // or -1 when it is none of them.
    auto component_of = [&](const char *c0, const char *c1, const char *c2) {
        if (fieldname() == c0) return 0;
        if (fieldname() == c1) return 1;
        if (fieldname() == c2) return 2;
        return -1;
    };

    int component = -1;
    if (lvtype.is_color())
        component = component_of ("r", "g", "b");
    if (component < 0 && lvtype.is_vectriple())
        component = component_of ("x", "y", "z");

    if (component >= 0) {
        OSL_DASSERT(fieldid == -1 && !m_compindex);
        fieldid = component;
        m_compindex.reset(new ASTindex(m_compiler, lv,
                                       new ASTliteral(oslcompiler, fieldid)));
        m_is_lvalue = true;
        return nullptr;
    }

    if (! lvtype.is_structure() && ! lvtype.is_structure_array()) {
        errorf("type '%s' does not have a member '%s'", lvtype, m_field);
        return NULL;
    }

    ustring structsymname;
    TypeSpec structtype;
    find_structsym (lv, structsymname, structtype);

    structid = structtype.structure();
    StructSpec *structspec (structtype.structspec());
    fieldid = -1;
    const int nfields = (int)structspec->numfields();
    for (int i = 0;  i < nfields && fieldid < 0;  ++i)
        if (structspec->field(i).name == m_field)
            fieldid = i;

    if (fieldid < 0) {
        errorf("struct type '%s' does not have a member '%s'",
               structspec->name(), m_field);
        return NULL;
    }

    const StructSpec::FieldSpec &fieldrec (structspec->field(fieldid));
    ustring fieldsymname = ustring::sprintf ("%s.%s", structsymname,
                                             fieldrec.name);
    return m_compiler->symtab().find (fieldsymname);
}




/// Given `structnode`, an AST node whose type is a struct (a struct
/// variable, a struct-typed field of another struct, or an element of an
/// array of structs), report the name of the symbol that represents that
/// struct in `structname` and its type in `structtype`. For an array
/// element, the array symbol's name is reported with the arrayness
/// stripped from the type.
void
ASTstructselect::find_structsym (ASTNode *structnode, ustring &structname,
                                 TypeSpec &structtype)
{
    // Flattening possibly nested struct-in-struct and/or array-of-struct
    // selections down to a single symbol for the particular struct.
    OSL_DASSERT (structnode->typespec().is_structure() ||
                 structnode->typespec().is_structure_array());
    switch (structnode->nodetype()) {
    case variable_ref_node: {
        // A top-level struct variable.
        ASTvariable_ref *var = static_cast<ASTvariable_ref *>(structnode);
        structname = var->name();
        structtype = var->typespec();
        break;
    }
    case structselect_node: {
        // A field of another struct that is itself a struct.
        ASTstructselect *outer = static_cast<ASTstructselect *>(structnode);
        int structid = -1, fieldid = -1;
        Symbol *sym = outer->find_fieldsym (structid, fieldid);
        structname = sym->name();
        structtype = sym->typespec();
        break;
    }
    case index_node: {
        // An element of an array of structs.
        ASTindex *arrayref = static_cast<ASTindex *>(structnode);
        find_structsym (arrayref->lvalue().get(), structname, structtype);
        structtype.make_array (0);  // clear its arrayness
        break;
    }
    default:
        OSL_ASSERT (0 && "Malformed ASTstructselect");
    }
}



// The single child is the structure (or triple) being selected from.
const char *
ASTstructselect::childname (size_t i) const
{
    static const char *const names[] = { "structure" };
    return names[i];
}



// Generic node dump, followed by the selected field name one level deeper.
void
ASTstructselect::print (std::ostream &out, int indentlevel) const
{
    ASTNode::print (out, indentlevel);
    const int field_indent = indentlevel + 1;
    indent (out, field_indent);
    out << "select " << field() << "\n";
}



// Labels for the (condition, true branch, false branch) children.
const char *
ASTconditional_statement::childname (size_t i) const
{
    static const char *const names[] = {
        "condition", "truestatement", "falsestatement"
    };
    return names[i];
}



// Loop node whose children are (init, cond, iter, body); the loop kind is
// stored as the op code.
ASTloop_statement::ASTloop_statement (OSLCompilerImpl *comp, LoopType looptype,
                                      ASTNode *init, ASTNode *cond,
                                      ASTNode *iter, ASTNode *stmt)
    : ASTNode (loop_statement_node, comp, looptype, init, cond, iter, stmt)
{
    // An omitted condition, as in for(;;), is treated like for(;1;):
    // substitute a literal 1 so the condition child is never empty.
    if (cond == NULL)
        m_children[1] = new ASTliteral(comp, 1);
}



const char *
ASTloop_statement::childname (size_t i) const
{
    static const char *const names[] = {
        "initializer", "condition", "iteration", "bodystatement"
    };
    return names[i];
}



const char *
ASTloop_statement::opname () const
{
    if (m_op == LoopWhile)
        return "while";
    if (m_op == LoopDo)
        return "dowhile";
    if (m_op == LoopFor)
        return "for";
    OSL_ASSERT(0 && "unknown loop type");
    return "unknown";
}



// break/continue carry no children.
const char *
ASTloopmod_statement::childname (size_t /*i*/) const
{
    return NULL;  // no children
}



const char *
ASTloopmod_statement::opname () const
{
    if (m_op == LoopModBreak)
        return "break";
    if (m_op == LoopModContinue)
        return "continue";
    OSL_ASSERT(0 && "unknown loop modifier");
    return "unknown";
}



// Every child slot of a return statement is labeled "expression".
const char *
ASTreturn_statement::childname (size_t /*i*/) const
{
    static const char *const label = "expression";
    return label;
}



// Compound initializer: built as a type constructor with a default
// (unknown) TypeSpec around `exprlist`, with m_ctor initially false.
ASTcompound_initializer::ASTcompound_initializer (OSLCompilerImpl *comp,
                                                  ASTNode *exprlist)
    : ASTtype_constructor (compound_initializer_node, comp, TypeSpec(), exprlist),
      m_ctor(false)
{
}



// The child label depends on whether the node can act as a constructor.
const char *
ASTcompound_initializer::childname (size_t /*i*/) const
{
    if (canconstruct())
        return "args";
    return "expression_list";
}



/// Check whether `var` may be written. Index and struct-select nodes are
/// peeled back to the variable they operate on. If the underlying symbol is
/// read-only, a warning is issued and false is returned. Writeable symbols,
/// and node kinds with no resolvable symbol, return true.
bool
ASTNode::check_symbol_writeability (ASTNode *var)
{
    Symbol *dest = nullptr;
    switch (var->nodetype()) {
    case index_node:
        return check_symbol_writeability (static_cast<ASTindex*>(var)->lvalue().get());
    case structselect_node:
        return check_symbol_writeability (static_cast<ASTstructselect*>(var)->lvalue().get());
    case variable_ref_node:
        dest = static_cast<ASTvariable_ref*>(var)->sym();
        break;
    case variable_declaration_node:
        dest = static_cast<ASTvariable_declaration*>(var)->sym();
        break;
    default:
        break;   // other node kinds are not checked
    }

    if (! dest || ! dest->readonly())
        return true;

    // Note: Consider it only a warning to write to a non-output
    // parameter. Users who want it to be a hard error can use
    // -Werror. Writing to any other readonly symbols is a full
    // error.
    warningf("cannot write to non-output parameter \"%s\"", dest->name());
    return false;
}



// Assignment node with children (variable, expression). A compound form
// `var op= expr` is rewritten on construction into `var = (var op expr)`:
// the op code becomes Assign and the right child becomes a binary
// expression that reuses the same left-hand node.
ASTassign_expression::ASTassign_expression (OSLCompilerImpl *comp, ASTNode *var,
                                            Operator op, ASTNode *expr)
    : ASTNode (assign_expression_node, comp, op, var, expr)
{
    const bool compound = (op != Assign);
    if (compound) {
        m_op = Assign;
        m_children[1] = new ASTbinary_expression (comp, op, var, expr);
    }
    check_symbol_writeability (var);
}



const char *
ASTassign_expression::childname (size_t i) const
{
    static const char *const names[] = { "variable", "expression" };
    return names[i];
}



// Source-level spelling of the assignment operator.
const char *
ASTassign_expression::opname () const
{
    const char *spelling = nullptr;
    switch (m_op) {
    case Assign     : spelling = "=";   break;
    case Mul        : spelling = "*=";  break;
    case Div        : spelling = "/=";  break;
    case Add        : spelling = "+=";  break;
    case Sub        : spelling = "-=";  break;
    case BitAnd     : spelling = "&=";  break;
    case BitOr      : spelling = "|=";  break;
    case Xor        : spelling = "^=";  break;
    case ShiftLeft  : spelling = "<<="; break;
    case ShiftRight : spelling = ">>="; break;
    default:
        OSL_ASSERT (0 && "unknown assignment expression");
        spelling = "="; // punt
        break;
    }
    return spelling;
}



// Word form of the assignment operator.
const char *
ASTassign_expression::opword () const
{
    const char *word = nullptr;
    switch (m_op) {
    case Assign     : word = "assign"; break;
    case Mul        : word = "mul";    break;
    case Div        : word = "div";    break;
    case Add        : word = "add";    break;
    case Sub        : word = "sub";    break;
    case BitAnd     : word = "bitand"; break;
    case BitOr      : word = "bitor";  break;
    case Xor        : word = "xor";    break;
    case ShiftLeft  : word = "shl";    break;
    case ShiftRight : word = "shr";    break;
    default:
        OSL_ASSERT (0 && "unknown assignment expression");
        word = "assign"; // punt
        break;
    }
    return word;
}



ASTunary_expression::ASTunary_expression (OSLCompilerImpl *comp, int op,
                                          ASTNode *expr)
    : ASTNode (unary_expression_node, comp, op, expr)
{
    // Check for a user-overloaded function for this operator
    Symbol *sym = comp->symtab().find (ustring::sprintf ("__operator__%s__", opword()));
    if (sym && sym->symtype() == SymTypeFunction)
        m_function_overload = (FunctionSymbol *)sym;
}



const char *
ASTunary_expression::childname (size_t i) const
{
    // A unary expression has a single child: its operand.
    static const char * const name[] = { "expression" };
    // Refuse an out-of-range index rather than reading past the end of the
    // table; NULL is an accepted "no name" answer (ASTliteral::childname
    // returns it for every child).
    return i < sizeof(name) / sizeof(name[0]) ? name[i] : NULL;
}



const char *
ASTunary_expression::opname () const
{
    switch (m_op) {
    case Add   : return "+";
    case Sub   : return "-";
    case Not   : return "!";
    case Compl : return "~";
    default:
        OSL_ASSERT (0 && "unknown unary expression");
        return "unknown";
    }
}



const char *
ASTunary_expression::opword () const
{
    switch (m_op) {
    case Add   : return "add";
    case Sub   : return "neg";
    case Not   : return "not";
    case Compl : return "compl";
    default:
        OSL_ASSERT (0 && "unknown unary expression");
        return "unknown";
    }
}



ASTbinary_expression::ASTbinary_expression (OSLCompilerImpl *comp, Operator op,
                                            ASTNode *left, ASTNode *right)
    : ASTNode (binary_expression_node, comp, op, left, right)
{
    // Logical && and || are never user-overloadable, so skip the lookup.
    if (op == And || op == Or)
        return;

    // For every other operator, look for a user-declared function named
    // "__operator__<opword>__" and remember it if it is a function.
    ustring overload_name = ustring::sprintf ("__operator__%s__", opword());
    Symbol *candidate = comp->symtab().find (overload_name);
    if (candidate && candidate->symtype() == SymTypeFunction)
        m_function_overload = static_cast<FunctionSymbol *>(candidate);
}



const char *
ASTbinary_expression::childname (size_t i) const
{
    // A binary expression has exactly two children: left and right operands.
    static const char * const name[] = { "left", "right" };
    // Refuse an out-of-range index rather than reading past the end of the
    // table; NULL is an accepted "no name" answer (ASTliteral::childname
    // returns it for every child).
    return i < sizeof(name) / sizeof(name[0]) ? name[i] : NULL;
}



const char *
ASTbinary_expression::opname () const
{
    switch (m_op) {
    case Mul          : return "*";
    case Div          : return "/";
    case Add          : return "+";
    case Sub          : return "-";
    case Mod          : return "%";
    case Equal        : return "==";
    case NotEqual     : return "!=";
    case Greater      : return ">";
    case GreaterEqual : return ">=";
    case Less         : return "<";
    case LessEqual    : return "<=";
    case BitAnd       : return "&";
    case BitOr        : return "|";
    case Xor          : return "^";
    case And          : return "&&";
    case Or           : return "||";
    case ShiftLeft    : return "<<";
    case ShiftRight   : return ">>";
    default:
        OSL_ASSERT (0 && "unknown binary expression");
        return "unknown";
    }
}



const char *
ASTbinary_expression::opword () const
{
    switch (m_op) {
    case Mul          : return "mul";
    case Div          : return "div";
    case Add          : return "add";
    case Sub          : return "sub";
    case Mod          : return "mod";
    case Equal        : return "eq";
    case NotEqual     : return "neq";
    case Greater      : return "gt";
    case GreaterEqual : return "ge";
    case Less         : return "lt";
    case LessEqual    : return "le";
    case BitAnd       : return "bitand";
    case BitOr        : return "bitor";
    case Xor          : return "xor";
    case And          : return "and";
    case Or           : return "or";
    case ShiftLeft    : return "shl";
    case ShiftRight   : return "shr";
    default:
        OSL_ASSERT (0 && "unknown binary expression");
        return "unknown";
    }
}



const char *
ASTternary_expression::childname (size_t i) const
{
    // The three children of "cond ? a : b", in order.
    static const char * const name[] = { "condition",
                                         "trueexpression", "falseexpression" };
    // Refuse an out-of-range index rather than reading past the end of the
    // table; NULL is an accepted "no name" answer (ASTliteral::childname
    // returns it for every child).
    return i < sizeof(name) / sizeof(name[0]) ? name[i] : NULL;
}



const char *
ASTtypecast_expression::childname (size_t i) const
{
    // A typecast has a single child: the expression being cast.
    static const char * const name[] = { "expr" };
    // Refuse an out-of-range index rather than reading past the end of the
    // table; NULL is an accepted "no name" answer (ASTliteral::childname
    // returns it for every child).
    return i < sizeof(name) / sizeof(name[0]) ? name[i] : NULL;
}



const char *
ASTtype_constructor::childname (size_t i) const
{
    // A type constructor has a single child: its argument list.
    static const char * const name[] = { "args" };
    // Refuse an out-of-range index rather than reading past the end of the
    // table; NULL is an accepted "no name" answer (ASTliteral::childname
    // returns it for every child).
    return i < sizeof(name) / sizeof(name[0]) ? name[i] : NULL;
}



ASTfunction_call::ASTfunction_call (OSLCompilerImpl *comp, ustring name,
                                    ASTNode *args, FunctionSymbol *funcsym)
    : ASTNode (function_call_node, comp, 0, args), m_name(name),
      m_sym(funcsym ? funcsym : comp->symtab().find (name)), // Look it up.
      m_poly(funcsym),    // Default - resolved symbol or null
      m_argread(~1),      // Default - all args are read except the first
      m_argwrite(1),      // Default - first arg only is written by the op
      m_argtakesderivs(0) // Default - doesn't take derivs
{
    // Validate the callee symbol found (or supplied) above.
    if (m_sym == nullptr) {
        // Nothing by this name is visible.
        // FIXME -- would be fun to troll through the symtab and try to
        // find the things that almost matched and offer suggestions.
        errorf("function '%s' was not declared in this scope", name);
        return;
    }

    // A struct constructor is acceptable as-is. Otherwise the symbol must
    // be a function; if it is not, report the error and clear m_sym.
    if (! is_struct_ctr() && m_sym->symtype() != SymTypeFunction) {
        errorf("'%s' is not a function", name);
        m_sym = nullptr;
    }
}



const char *
ASTfunction_call::childname (size_t i) const
{
    return ustring::sprintf ("param%d", (int)i).c_str();
}



const char *
ASTfunction_call::opname () const
{
    // For a call node, the "operator" is the name of the function called.
    const char *callee = m_name.c_str();
    return callee;
}



void
ASTfunction_call::print (std::ostream &out, int indentlevel) const
{
    // Function calls use the generic node printer unchanged. An older
    // variant that also printed the body of a user-defined callee was
    // disabled with #if 0 and has been dropped.
    ASTNode::print (out, indentlevel);
}



const char *
ASTliteral::childname (size_t /*i*/) const
{
    // Literals have no named children; every index yields a null name.
    return nullptr;
}



void
ASTliteral::print (std::ostream &out, int indentlevel) const
{
    // Header: node type name and the literal's type.
    indent (out, indentlevel);
    out << "(" << nodetypename() << " (type: "
        << m_typespec.string() << ") ";

    // Value: only int, float, and string literals print one here. Strings
    // are wrapped in double quotes.
    if (m_typespec.is_int()) {
        out << m_i;
    } else if (m_typespec.is_float()) {
        out << m_f;
    } else if (m_typespec.is_string()) {
        out << '"' << m_s << '"';
    }

    out << ")\n";
}


}; // namespace pvt

OSL_NAMESPACE_EXIT
